DeepMind/Reverb usage

Following the previous post, I write more about DeepMind/Reverb, an experience replay framework.

When reading the source code, I found that Reverb has other ways of adding and sampling transition data.

Reverb has two clients: reverb.Client for prototype development and reverb.TFClient for actual training. TFClient is designed to work inside a TensorFlow graph.
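The snippets below assume a Reverb server is already running (set up in the previous post) and reachable via server.port. For reference, a minimal server along the lines of the Reverb README might look like this (the table name and parameters are placeholders):

import reverb

# A single prioritized replay table; max_size and priority_exponent are placeholders.
server = reverb.Server(tables=[
    reverb.Table(name="ReplayBuffer",
                 sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
                 remover=reverb.selectors.Fifo(),
                 max_size=1000,
                 rate_limiter=reverb.rate_limiters.MinSize(1))
])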

Add

reverb.Client.insert

The simplest way to add data is the reverb.Client.insert method, which stores a single transition. The client's internal buffer is flushed on every call.
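In the snippets below, obs, act, rew, next_obs, done, and priority are assumed to be ordinary NumPy/Python values, for example (hypothetical CartPole-like shapes):

import numpy as np

# Hypothetical transition; observations have length 4, everything else is
# stored as a length-1 vector (see the note about shapes near the end).
obs      = np.array([0.0, 0.1, 0.2, 0.3])
act      = np.array([1.0])
rew      = np.array([1.0])
next_obs = np.array([0.1, 0.2, 0.3, 0.4])
done     = np.array([0.0])
priority = 1.0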

import reverb

table_name = "ReplayBuffer"
client = reverb.Client(f"localhost:{server.port}")

client.insert([obs, act, rew, next_obs, done], priorities={table_name: priority})

reverb.Client.writer

The second way to add data is to create a reverb.Writer via reverb.Client.writer. This method is used internally by reverb.Client.insert.

First, append transition(s) to the reverb.Writer, then call create_item. An item can be a single transition or a multi-step sequence such as a whole trajectory. The internal buffer is flushed when the reverb.Writer is closed, so it is convenient to use it as a context manager.

import reverb

table_name = "ReplayBuffer"
client = reverb.Client(f"localhost:{server.port}")

with client.writer(max_sequence_length=3) as writer:
    writer.append([obs, act, rew, next_obs, done])
    writer.create_item(table=table_name, num_timesteps=1, priority=priority)

    writer.append([obs, act, rew, next_obs, done])
    writer.append([obs, act, rew, next_obs, done])
    writer.append([obs, act, rew, next_obs, done])
    writer.create_item(table=table_name, num_timesteps=3, priority=priority)  # 3 steps are stored as one single item
# When exiting the with-block, the internal buffer is flushed.

reverb.TFClient.insert

The last addition method, reverb.TFClient.insert, is similar to reverb.Client.insert, except that its arguments are tf.Tensor.

import reverb
import tensorflow as tf

table_name = "ReplayBuffer"
tf_client = reverb.TFClient(f"localhost:{server.port}")

tf_client.insert([tf.constant(obs),
                  tf.constant(act),
                  tf.constant(rew),
                  tf.constant(next_obs),
                  tf.constant(done)],
                 tables=tf.constant([table_name]),
                 priorities=tf.constant([priority], dtype=tf.float64))

There are two important points: the tables argument must be a rank-1 tf.Tensor of str, and the priorities argument must be a rank-1 tf.Tensor of float64.

Sample

reverb.Client.sample

The simplest sampling method is reverb.Client.sample, which returns a generator of reverb.replay_sample.ReplaySample (a named tuple with info and data fields).

batch_size = 32

client.sample(table_name, num_samples=batch_size)
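Iterating over the generator gives the sampled items. In the version of the source I read, each yielded element is a list of per-timestep ReplaySample, and the info field carries the key and probability needed later for priority updates (field names may differ slightly between Reverb versions). A usage sketch:

for sample in client.sample(table_name, num_samples=batch_size):
    for timestep in sample:                      # one entry per stored timestep
        key = timestep.info.key                  # item key, used for priority updates
        probability = timestep.info.probability  # sampling probability
        obs, act, rew, next_obs, done = timestep.data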

reverb.TFClient.sample

The second sampling method is reverb.TFClient.sample. This method requires the output dtypes and does not support batched sampling. The return type is reverb.replay_sample.ReplaySample.

dtypes = [tf.float64, tf.float64, tf.float64, tf.float64, tf.float64]

tf_client.sample(table_name, dtypes)
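The returned ReplaySample can be unpacked directly; its data entries are tf.Tensor (a small sketch under the same transition layout as above):

sample = tf_client.sample(table_name, dtypes)
obs, act, rew, next_obs, done = sample.data   # each is a tf.Tensor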

reverb.TFClient.dataset

The last method is reverb.TFClient.dataset, which returns a reverb.ReplayDataset derived from tf.data.Dataset.

The ReplayDataset can be used like a generator; it internally fetches transitions from the replay buffer server with appropriate timing.

This is the preferred method for large-scale distributed reinforcement learning.

dtypes = [tf.float64, tf.float64, tf.float64, tf.float64, tf.float64]
shapes = [4, 1, 1, 4, 1]

dataset = tf_client.dataset(table_name, dtypes, shapes)

One caveat: the shapes argument cannot accept 0 as an element, so when you add transitions, make sure every tf.Tensor has at least rank 1 (e.g. store a scalar reward as a length-1 vector).
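Since ReplayDataset is derived from tf.data.Dataset, the usual transformations apply. A consumption sketch under the assumptions above (batching is optional):

batched = dataset.batch(batch_size)

for sample in batched.take(1):
    # sample.data holds batched tf.Tensor values; sample.info holds the
    # corresponding item keys, probabilities, etc.
    obs, act, rew, next_obs, done = sample.data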

Update Priorities

reverb.Client.mutate_priorities

client.mutate_priorities(table_name, updates={key: priority})
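The keys passed to mutate_priorities are the item keys obtained from sampling (info.key). A sketch that updates priorities for a sampled batch; the new priority value is just a placeholder (in practice it would be, e.g., the absolute TD error):

updates = {}
for sample in client.sample(table_name, num_samples=batch_size):
    for timestep in sample:
        updates[timestep.info.key] = 0.5   # placeholder priority

client.mutate_priorities(table_name, updates=updates)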

reverb.TFClient.update_priorities

tf_client.update_priorities(tf.constant([table_name]), keys=key, priorities=priority)
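As with TFClient.insert, the arguments must be tensors: keys is a rank-1 tensor of item keys (Reverb keys are uint64, obtained from sampling info) and priorities a rank-1 tensor of float64. A hypothetical call with explicit constants:

tf_client.update_priorities(tf.constant([table_name]),
                            keys=tf.constant([key], dtype=tf.uint64),        # key from a sampled item's info
                            priorities=tf.constant([0.5], dtype=tf.float64))  # placeholder priority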